package org.carrot2.elasticsearch;
import static org.carrot2.elasticsearch.LoggerUtils.emitErrorResponse;
import static org.elasticsearch.action.ValidateActions.addValidationError;
import static org.elasticsearch.rest.RestRequest.Method.POST;
import static org.elasticsearch.rest.RestRequest.Method.GET;
import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.TimeUnit;
import org.carrot2.core.Cluster;
import org.carrot2.core.Controller;
import org.carrot2.core.Document;
import org.carrot2.core.LanguageCode;
import org.carrot2.core.ProcessingException;
import org.carrot2.core.ProcessingResult;
import org.carrot2.core.attribute.CommonAttributesDescriptor;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.Action;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestBuilder;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.TransportSearchAction;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.TransportAction;
import org.elasticsearch.client.ElasticsearchClient;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.logging.Loggers;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.index.query.QueryParseContext;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.BytesRestResponse;
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.action.search.RestSearchAction;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHitField;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.aggregations.Aggregation;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightField;
import org.elasticsearch.search.internal.InternalSearchResponse;
import org.elasticsearch.search.profile.SearchProfileShardResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportRequestHandler;
import org.elasticsearch.transport.TransportService;
/**
* Perform clustering of search results.
*/
public class ClusteringAction
extends Action<ClusteringAction.ClusteringActionRequest,
ClusteringAction.ClusteringActionResponse,
ClusteringAction.ClusteringActionRequestBuilder> {
/**
 * Action name. Passed to the superclass constructor and used as the key under which
 * the transport request handler is registered.
 */
public static final String NAME = "clustering/cluster";
/** Reusable singleton; the constructor is private so this is the only instance. */
public static final ClusteringAction INSTANCE = new ClusteringAction();
/**
 * Private constructor: registers the action under {@link #NAME}. Use {@link #INSTANCE}.
 */
private ClusteringAction() {
super(NAME);
}
/**
 * Creates an empty response instance; its state is expected to be filled in
 * afterwards (for example by {@code readFrom}).
 *
 * @return a new, empty {@link ClusteringActionResponse}
 */
@Override
public ClusteringActionResponse newResponse() {
return new ClusteringActionResponse();
}
/**
 * Creates a request builder bound to the given client.
 *
 * @param client the client the builder will use
 * @return a new {@link ClusteringActionRequestBuilder}
 */
@Override
public ClusteringActionRequestBuilder newRequestBuilder(ElasticsearchClient client) {
return new ClusteringActionRequestBuilder(client);
}
/**
 * An {@link ActionRequest} for {@link ClusteringAction}.
 * <p>
 * Bundles the delegate {@link SearchRequest} with the clustering parameters: the query hint,
 * the mapping of hit fields to logical document fields, the algorithm identifier, the maximum
 * number of hits to return and optional runtime attribute overrides.
 */
public static class ClusteringActionRequest extends ActionRequest {
    /** Delegate search request that fetches the documents to be clustered. */
    private SearchRequest searchRequest;
    /** Query terms hint; may be empty but must not be null for the request to validate. */
    private String queryHint;
    /** Mappings from hit fields to logical document fields, in insertion order. */
    private List<FieldMappingSpec> fieldMapping = new ArrayList<>();
    /** Identifier of the clustering algorithm or null to use the default one. */
    private String algorithm;
    /** Maximum number of hits to return; {@link Integer#MAX_VALUE} means "all hits". */
    private int maxHits = Integer.MAX_VALUE;
    /** Runtime override attributes for the clustering algorithm; may be null. */
    private Map<String, Object> attributes;

    /**
     * Set the {@link SearchRequest} to use for fetching documents to be clustered.
     * The search request must fetch enough documents for clustering to make sense
     * (set <code>size</code> appropriately).
     *
     * @param searchRequest search request to set
     * @return same request instance
     */
    public ClusteringActionRequest setSearchRequest(SearchRequest searchRequest) {
        this.searchRequest = searchRequest;
        return this;
    }

    /**
     * Same as {@link #setSearchRequest(SearchRequest)}, taking the request from a builder.
     *
     * @param builder builder whose request should be used
     * @return same request instance
     * @see #setSearchRequest(SearchRequest)
     */
    public ClusteringActionRequest setSearchRequest(SearchRequestBuilder builder) {
        return setSearchRequest(builder.request());
    }

    /**
     * @return the delegate search request or null if none has been set
     * @see #setSearchRequest(SearchRequest)
     */
    public SearchRequest getSearchRequest() {
        return searchRequest;
    }

    /**
     * @param queryHint A set of terms which correspond to the query. This hint helps the
     *  clustering algorithm to avoid trivial clusters around the query terms. Typically the query
     *  terms hint will be identical to what the user typed in the search box.
     *  <p>
     *  The hint may be an empty string but must not be <code>null</code>.
     * @return same request instance
     */
    public ClusteringActionRequest setQueryHint(String queryHint) {
        this.queryHint = queryHint;
        return this;
    }

    /**
     * @return the query hint or null if none has been set
     * @see #setQueryHint(String)
     */
    public String getQueryHint() {
        return queryHint;
    }

    /**
     * Sets the identifier of the clustering algorithm to use. If <code>null</code>, the default
     * algorithm will be used (depending on what's available).
     *
     * @param algorithm algorithm identifier or null
     * @return same request instance
     */
    public ClusteringActionRequest setAlgorithm(String algorithm) {
        this.algorithm = algorithm;
        return this;
    }

    /**
     * @return the algorithm identifier or null if the default should be used
     * @see #setAlgorithm
     */
    public String getAlgorithm() {
        return algorithm;
    }

    /**
     * Returns <code>true</code> if at least one hit may be returned with the response
     * (that is, if {@link #getMaxHits()} is greater than zero).
     *
     * @deprecated Use {@link #setMaxHits} and set it to zero instead.
     */
    @Deprecated()
    public boolean getIncludeHits() {
        return maxHits > 0;
    }

    /**
     * Sets whether to include hits with clustering results. If only cluster labels
     * are needed the hits may be omitted to save bandwidth.
     *
     * @param includeHits true maps to "all hits", false maps to "no hits"
     * @return same request instance
     * @deprecated Use {@link #setMaxHits} instead.
     */
    @Deprecated()
    public ClusteringActionRequest setIncludeHits(boolean includeHits) {
        setMaxHits(includeHits ? Integer.MAX_VALUE : 0);
        return this;
    }

    /**
     * Sets the maximum number of hits to return with the response. Setting this
     * value to zero will only return clusters, without any hits (can be used
     * to save bandwidth if only cluster labels are needed).
     * <p>
     * Set to {@link Integer#MAX_VALUE} to include all the hits.
     *
     * @param maxHits maximum number of hits, must not be negative
     * @throws IllegalArgumentException if <code>maxHits</code> is negative
     */
    public void setMaxHits(int maxHits) {
        // An explicit check (rather than an assert) so that a negative value is rejected
        // even when assertions are disabled; it would otherwise surface much later as a
        // negative array size when the hits are trimmed.
        if (maxHits < 0) {
            throw new IllegalArgumentException("maxHits must not be negative: " + maxHits);
        }
        this.maxHits = maxHits;
    }

    /**
     * Sets {@link #setMaxHits(int)} from a string. An empty string or null means
     * all hits should be included.
     *
     * @param value decimal integer, blank or null
     * @throws NumberFormatException if the value is not blank and not a valid integer
     * @throws IllegalArgumentException if the parsed value is negative
     */
    public void setMaxHits(String value) {
        if (value == null || value.trim().isEmpty()) {
            setMaxHits(Integer.MAX_VALUE);
        } else {
            setMaxHits(Integer.parseInt(value));
        }
    }

    /**
     * Returns the maximum number of hits to be returned as part of the response.
     * If equal to {@link Integer#MAX_VALUE}, then all hits will be returned.
     */
    public int getMaxHits() {
        return maxHits;
    }

    /**
     * Sets a map of runtime override attributes for clustering algorithms.
     * The map is stored by reference (no copy is made).
     *
     * @param map attribute overrides, may be null
     * @return same request instance
     */
    public ClusteringActionRequest setAttributes(Map<String, Object> map) {
        this.attributes = map;
        return this;
    }

    /**
     * @return the attribute overrides or null if none have been set
     * @see #setAttributes(Map)
     */
    public Map<String, Object> getAttributes() {
        return attributes;
    }

    /**
     * Parses some {@link org.elasticsearch.common.xcontent.XContent} and fills in the request.
     * <p>
     * A null or empty source leaves the request untouched. The recognized top-level keys are
     * <code>query_hint</code>, <code>field_mapping</code>, <code>algorithm</code>,
     * <code>attributes</code>, <code>search_request</code>, <code>include_hits</code>
     * (deprecated) and <code>max_hits</code>; keys that are absent do not modify the
     * corresponding property.
     *
     * @param source the content to parse
     * @param xContentType content type of <code>source</code>
     * @param xContentRegistry registry used when creating parsers
     * @throws ClusteringException if the source cannot be parsed or contains invalid values
     */
    @SuppressWarnings("unchecked")
    public void source(BytesReference source, XContentType xContentType, NamedXContentRegistry xContentRegistry) {
        if (source == null || source.length() == 0) {
            return;
        }

        try (XContentParser parser = XContentHelper.createParser(xContentRegistry, source, xContentType)) {
            // TODO: we should avoid reparsing search_request here
            // but it's terribly difficult to slice the underlying byte
            // buffer to get just the search request.
            Map<String, Object> asMap = parser.mapOrdered();

            String queryHint = (String) asMap.get("query_hint");
            if (queryHint != null) {
                setQueryHint(queryHint);
            }

            Map<String, List<String>> fieldMapping = (Map<String, List<String>>) asMap.get("field_mapping");
            if (fieldMapping != null) {
                parseFieldSpecs(fieldMapping);
            }

            String algorithm = (String) asMap.get("algorithm");
            if (algorithm != null) {
                setAlgorithm(algorithm);
            }

            Map<String, Object> attributes = (Map<String, Object>) asMap.get("attributes");
            if (attributes != null) {
                setAttributes(attributes);
            }

            Map<String, ?> searchRequestMap = (Map<String, ?>) asMap.get("search_request");
            if (searchRequestMap != null) {
                if (this.searchRequest == null) {
                    searchRequest = new SearchRequest();
                }

                // Re-serialize the nested map to JSON and parse it as a search source. The
                // secondary parser is closed as soon as the source builder has been created.
                XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON).map(searchRequestMap);
                try (XContentParser searchXParser = XContentFactory.xContent(XContentType.JSON)
                        .createParser(xContentRegistry, builder.bytes())) {
                    QueryParseContext parseContext = new QueryParseContext(searchXParser);
                    SearchSourceBuilder searchSourceBuilder =
                        SearchSourceBuilder.fromXContent(parseContext);
                    searchRequest.source(searchSourceBuilder);
                }
            }

            Object includeHits = asMap.get("include_hits");
            if (includeHits != null) {
                Loggers.getLogger(getClass()).warn("Request used deprecated 'include_hits' parameter.");
                setIncludeHits(Boolean.parseBoolean(includeHits.toString()));
            }

            // Processed after include_hits, so max_hits wins if both are present.
            Object maxHits = asMap.get("max_hits");
            if (maxHits != null) {
                setMaxHits(maxHits.toString());
            }
        } catch (Exception e) {
            String sSource = "_na_";
            try {
                sSource = XContentHelper.convertToJson(source, false, false, xContentType);
            } catch (Throwable ignored) {
                // Best effort only: the JSON rendering is used solely for the error message.
            }
            throw new ClusteringException("Failed to parse source [" + sSource + "]", e);
        }
    }

    /**
     * Adds field mapping specifications keyed by logical field name. Keys for which
     * {@link LogicalField#valueOfCaseInsensitive} returns null are skipped.
     */
    private void parseFieldSpecs(Map<String, List<String>> fieldSpecs) {
        for (Map.Entry<String, List<String>> e : fieldSpecs.entrySet()) {
            LogicalField logicalField = LogicalField.valueOfCaseInsensitive(e.getKey());
            if (logicalField != null) {
                for (String fieldSpec : e.getValue()) {
                    addFieldMappingSpec(fieldSpec, logicalField);
                }
            }
        }
    }

    /**
     * Map a hit's field to a logical section of a document to be clustered (title, content or URL).
     *
     * @param fieldName name of the hit's field
     * @param logicalField logical field the content should be assigned to
     * @return same request instance
     * @see LogicalField
     */
    public ClusteringActionRequest addFieldMapping(String fieldName, LogicalField logicalField) {
        fieldMapping.add(new FieldMappingSpec(fieldName, logicalField, FieldSource.FIELD));
        return this;
    }

    /**
     * Map a hit's source field (field unpacked from the <code>_source</code> document)
     * to a logical section of a document to be clustered (title, content or URL).
     *
     * @param sourceFieldName name (path) of the field inside <code>_source</code>
     * @param logicalField logical field the content should be assigned to
     * @return same request instance
     * @see LogicalField
     */
    public ClusteringActionRequest addSourceFieldMapping(String sourceFieldName, LogicalField logicalField) {
        fieldMapping.add(new FieldMappingSpec(sourceFieldName, logicalField, FieldSource.SOURCE));
        return this;
    }

    /**
     * Map a hit's highlighted field (fragments of the original field) to a logical section
     * of a document to be clustered. This may be used to decrease the amount of information
     * passed to the clustering engine but also to "focus" the clustering engine on the context
     * of the query.
     *
     * @param fieldName name of the highlighted field
     * @param logicalField logical field the content should be assigned to
     * @return same request instance
     */
    public ClusteringActionRequest addHighlightedFieldMapping(String fieldName, LogicalField logicalField) {
        fieldMapping.add(new FieldMappingSpec(fieldName, logicalField, FieldSource.HIGHLIGHT));
        return this;
    }

    /**
     * Add a (valid!) field mapping specification to a logical field.
     *
     * @param fieldSpec field specification as understood by {@link FieldSource#parseSpec}
     * @param logicalField logical field the content should be assigned to
     * @return same request instance
     * @throws ElasticsearchException if the specification has no recognizable source
     * @see FieldSource
     */
    public ClusteringActionRequest addFieldMappingSpec(String fieldSpec, LogicalField logicalField) {
        FieldSource.ParsedFieldSource pfs = FieldSource.parseSpec(fieldSpec);
        if (pfs.source == null) {
            throw new ElasticsearchException("Field mapping specification must contain a " +
                " valid source prefix for the field source: " + fieldSpec);
        }

        switch (pfs.source) {
            case HIGHLIGHT:
                addHighlightedFieldMapping(pfs.fieldName, logicalField);
                break;

            case FIELD:
                addFieldMapping(pfs.fieldName, logicalField);
                break;

            case SOURCE:
                addSourceFieldMapping(pfs.fieldName, logicalField);
                break;

            default:
                throw new RuntimeException();
        }
        return this;
    }

    /**
     * Access to prepared field mapping (the internal, mutable list).
     */
    List<FieldMappingSpec> getFieldMapping() {
        return fieldMapping;
    }

    /**
     * Validates this request: a delegate search request must be present (and valid on its
     * own), the query hint must not be null and at least one field mapping must exist.
     *
     * @return the accumulated validation errors or null if the request is valid
     */
    @Override
    public ActionRequestValidationException validate() {
        ActionRequestValidationException validationException = null;
        if (searchRequest == null) {
            validationException = addValidationError("No delegate search request",
                validationException);
        }

        if (queryHint == null) {
            validationException = addValidationError("query hint may be empty but must not be null.",
                validationException);
        }

        if (fieldMapping.isEmpty()) {
            validationException = addValidationError("At least one field should be mapped to a logical document field.",
                validationException);
        }

        // Only delegate when there is a search request; its absence has been reported above
        // and must not end in a NullPointerException here.
        if (searchRequest != null) {
            ActionRequestValidationException ex = searchRequest.validate();
            if (ex != null) {
                if (validationException == null) {
                    validationException = new ActionRequestValidationException();
                }
                validationException.addValidationErrors(ex.validationErrors());
            }
        }

        return validationException;
    }

    /**
     * Serializes the request. The field order must mirror {@link #readFrom(StreamInput)}.
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        assert searchRequest != null;
        this.searchRequest.writeTo(out);
        out.writeOptionalString(queryHint);
        out.writeOptionalString(algorithm);
        out.writeInt(maxHits);

        out.writeVInt(fieldMapping.size());
        for (FieldMappingSpec spec : fieldMapping) {
            spec.writeTo(out);
        }

        boolean hasAttributes = (attributes != null);
        out.writeBoolean(hasAttributes);
        if (hasAttributes) {
            out.writeMap(attributes);
        }
    }

    /**
     * Deserializes the request. The field order must mirror {@link #writeTo(StreamOutput)}.
     * Field mappings read from the stream are appended to the current list.
     */
    @Override
    public void readFrom(StreamInput in) throws IOException {
        SearchRequest searchRequest = new SearchRequest();
        searchRequest.readFrom(in);

        this.searchRequest = searchRequest;
        this.queryHint = in.readOptionalString();
        this.algorithm = in.readOptionalString();
        this.maxHits = in.readInt();

        int count = in.readVInt();
        while (count-- > 0) {
            FieldMappingSpec spec = new FieldMappingSpec();
            spec.readFrom(in);
            fieldMapping.add(spec);
        }

        boolean hasAttributes = in.readBoolean();
        if (hasAttributes) {
            attributes = in.readMap();
        }
    }
}
/**
 * An {@link ActionRequestBuilder} for {@link ClusteringAction}. Every method forwards to the
 * underlying {@link ClusteringActionRequest} and returns this builder for chaining.
 */
public static class ClusteringActionRequestBuilder
    extends ActionRequestBuilder<ClusteringActionRequest,
                                 ClusteringActionResponse,
                                 ClusteringActionRequestBuilder> {

    /**
     * Creates a builder around a fresh, empty request.
     *
     * @param client the client used by this builder
     */
    public ClusteringActionRequestBuilder(ElasticsearchClient client) {
        super(client, ClusteringAction.INSTANCE, new ClusteringActionRequest());
    }

    /** @see ClusteringActionRequest#setSearchRequest(SearchRequestBuilder) */
    public ClusteringActionRequestBuilder setSearchRequest(SearchRequestBuilder builder) {
        request.setSearchRequest(builder);
        return this;
    }

    /** @see ClusteringActionRequest#setSearchRequest(SearchRequest) */
    public ClusteringActionRequestBuilder setSearchRequest(SearchRequest searchRequest) {
        request.setSearchRequest(searchRequest);
        return this;
    }

    /**
     * Sets the query hint, rejecting null eagerly.
     *
     * @throws IllegalArgumentException if the hint is null
     * @see ClusteringActionRequest#setQueryHint(String)
     */
    public ClusteringActionRequestBuilder setQueryHint(String queryHint) {
        if (queryHint != null) {
            request.setQueryHint(queryHint);
            return this;
        }
        throw new IllegalArgumentException("Query hint may be empty but must not be null.");
    }

    /** @see ClusteringActionRequest#setAlgorithm(String) */
    public ClusteringActionRequestBuilder setAlgorithm(String algorithm) {
        request.setAlgorithm(algorithm);
        return this;
    }

    /** @see ClusteringActionRequest#source(BytesReference, XContentType, NamedXContentRegistry) */
    public ClusteringActionRequestBuilder setSource(BytesReference content,
                                                    XContentType xContentType,
                                                    NamedXContentRegistry xContentRegistry) {
        request.source(content, xContentType, xContentRegistry);
        return this;
    }

    /**
     * A null argument is treated as "include hits".
     *
     * @deprecated Use {@link #setMaxHits} instead.
     */
    @Deprecated()
    public ClusteringActionRequestBuilder setIncludeHits(String includeHits) {
        boolean include = (includeHits == null) || Boolean.parseBoolean(includeHits);
        request.setIncludeHits(include);
        return this;
    }

    /** @see ClusteringActionRequest#setMaxHits(int) */
    public ClusteringActionRequestBuilder setMaxHits(int maxHits) {
        request.setMaxHits(maxHits);
        return this;
    }

    /** @see ClusteringActionRequest#setMaxHits(String) */
    public ClusteringActionRequestBuilder setMaxHits(String maxHits) {
        request.setMaxHits(maxHits);
        return this;
    }

    /**
     * Merges the given attributes into the request's attribute map, creating that map
     * first if the request has none.
     */
    public ClusteringActionRequestBuilder addAttributes(Map<String, Object> attributes) {
        Map<String, Object> current = request.getAttributes();
        if (current == null) {
            request.setAttributes(new HashMap<>());
            current = request.getAttributes();
        }
        current.putAll(attributes);
        return this;
    }

    /** Adds a single attribute; shorthand for {@link #addAttributes(Map)}. */
    public ClusteringActionRequestBuilder addAttribute(String key, Object value) {
        Map<String, Object> single = new HashMap<>();
        single.put(key, value);
        return addAttributes(single);
    }

    /** Replaces the request's attribute map with the given one. */
    public ClusteringActionRequestBuilder setAttributes(Map<String, Object> attributes) {
        request.setAttributes(attributes);
        return this;
    }

    /** @see ClusteringActionRequest#addFieldMapping(String, LogicalField) */
    public ClusteringActionRequestBuilder addFieldMapping(String fieldName, LogicalField logicalField) {
        request.addFieldMapping(fieldName, logicalField);
        return this;
    }

    /** @see ClusteringActionRequest#addSourceFieldMapping(String, LogicalField) */
    public ClusteringActionRequestBuilder addSourceFieldMapping(String fieldName, LogicalField logicalField) {
        request.addSourceFieldMapping(fieldName, logicalField);
        return this;
    }

    /** @see ClusteringActionRequest#addHighlightedFieldMapping(String, LogicalField) */
    public ClusteringActionRequestBuilder addHighlightedFieldMapping(String fieldName, LogicalField logicalField) {
        request.addHighlightedFieldMapping(fieldName, logicalField);
        return this;
    }

    /** @see ClusteringActionRequest#addFieldMappingSpec(String, LogicalField) */
    public ClusteringActionRequestBuilder addFieldMappingSpec(String fieldSpec, LogicalField logicalField) {
        request.addFieldMappingSpec(fieldSpec, logicalField);
        return this;
    }
}
/**
 * An {@link ActionResponse} for {@link ClusteringAction}: the (possibly trimmed) search
 * response, the top-level document groups and a map of informational entries.
 */
public static class ClusteringActionResponse extends ActionResponse implements ToXContent {
    /**
     * Clustering-related response fields.
     */
    static final class Fields {
        static final String SEARCH_RESPONSE = "search_response";
        static final String CLUSTERS = "clusters";
        static final String INFO = "info";

        // from SearchResponse
        static final String _SCROLL_ID = "_scroll_id";
        static final String _SHARDS = "_shards";
        static final String TOTAL = "total";
        static final String SUCCESSFUL = "successful";
        static final String FAILED = "failed";
        static final String FAILURES = "failures";
        static final String STATUS = "status";
        static final String INDEX = "index";
        static final String SHARD = "shard";
        static final String REASON = "reason";
        static final String TOOK = "took";
        static final String TIMED_OUT = "timed_out";

        /**
         * {@link Fields#INFO} keys.
         */
        static final class Info {
            public static final String ALGORITHM = "algorithm";
            public static final String SEARCH_MILLIS = "search-millis";
            public static final String CLUSTERING_MILLIS = "clustering-millis";
            public static final String TOTAL_MILLIS = "total-millis";
            public static final String INCLUDE_HITS = "include-hits";
            public static final String MAX_HITS = "max-hits";
        }
    }

    private SearchResponse searchResponse;
    private DocumentGroup[] topGroups;
    private Map<String, String> info;

    /** Creates an empty response, to be populated by {@link #readFrom(StreamInput)}. */
    ClusteringActionResponse() {
    }

    /**
     * Creates a fully populated response. All arguments are required; the info map is
     * exposed through an unmodifiable view.
     */
    public ClusteringActionResponse(
        SearchResponse searchResponse,
        DocumentGroup[] topGroups,
        Map<String, String> info) {
        this.searchResponse = Preconditions.checkNotNull(searchResponse);
        this.topGroups = Preconditions.checkNotNull(topGroups);
        this.info = Collections.unmodifiableMap(Preconditions.checkNotNull(info));
    }

    /** @return the search response carried by this response (may be null if never set) */
    public SearchResponse getSearchResponse() {
        return searchResponse;
    }

    /** @return the top-level document groups */
    public DocumentGroup[] getDocumentGroups() {
        return topGroups;
    }

    /** @return the informational entries, keyed by the constants in {@link Fields.Info} */
    public Map<String, String> getInfo() {
        return info;
    }

    /**
     * Emits the search response content (if present), followed by the clusters array
     * and the info map.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params)
        throws IOException {
        if (searchResponse != null) {
            searchResponse.innerToXContent(builder, ToXContent.EMPTY_PARAMS);
        }

        builder.startArray(Fields.CLUSTERS);
        if (topGroups != null) {
            for (int i = 0; i < topGroups.length; i++) {
                topGroups[i].toXContent(builder, params);
            }
        }
        builder.endArray();

        builder.field(Fields.INFO, info);
        return builder;
    }

    /**
     * Serializes this response; the layout is mirrored by {@link #readFrom(StreamInput)}.
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);

        // Optional search response, preceded by a presence flag.
        if (searchResponse == null) {
            out.writeBoolean(false);
        } else {
            out.writeBoolean(true);
            searchResponse.writeTo(out);
        }

        // Group count followed by the groups; a missing array is written as zero groups.
        if (topGroups == null) {
            out.writeVInt(0);
        } else {
            out.writeVInt(topGroups.length);
            for (DocumentGroup group : topGroups) {
                group.writeTo(out);
            }
        }

        // Entry count followed by key/value pairs; a missing map is written as zero entries.
        if (info == null) {
            out.writeVInt(0);
        } else {
            out.writeVInt(info.size());
            for (Map.Entry<String, String> entry : info.entrySet()) {
                out.writeOptionalString(entry.getKey());
                out.writeOptionalString(entry.getValue());
            }
        }
    }

    /**
     * Deserializes this response; the layout is mirrored by {@link #writeTo(StreamOutput)}.
     */
    @Override
    public void readFrom(StreamInput in) throws IOException {
        super.readFrom(in);

        if (in.readBoolean()) {
            this.searchResponse = new SearchResponse();
            this.searchResponse.readFrom(in);
        }

        final int groupCount = in.readVInt();
        topGroups = new DocumentGroup[groupCount];
        for (int slot = 0; slot < groupCount; slot++) {
            topGroups[slot] = new DocumentGroup();
            topGroups[slot].readFrom(in);
        }

        int remaining = in.readVInt();
        info = new LinkedHashMap<>();
        while (remaining-- > 0) {
            String key = in.readOptionalString();
            String value = in.readOptionalString();
            info.put(key, value);
        }
    }

    @Override
    public String toString() {
        return ToString.objectToJson(this);
    }
}
/**
* A {@link TransportAction} for {@link ClusteringAction}.
*/
public static class TransportClusteringAction
extends TransportAction<ClusteringAction.ClusteringActionRequest,
ClusteringAction.ClusteringActionResponse> {
/** Language code strings for which an "unsupported code" warning has already been logged. */
private final Set<String> langCodeWarnings = new CopyOnWriteArraySet<>();
/** Search action that the clustering request's search is delegated to. */
private final TransportSearchAction searchAction;
/** Supplies the clustering controller and the list of available algorithm identifiers. */
private final ControllerSingleton controllerSingleton;
/**
 * Injected constructor. Stores the collaborators and registers a transport request
 * handler for {@link ClusteringAction#NAME}, passing {@code ThreadPool.Names.SAME}
 * as the executor name.
 * <p>
 * Note: {@code xContentRegistry} is accepted but not used within this constructor.
 */
@Inject
public TransportClusteringAction(Settings settings,
ThreadPool threadPool,
TransportService transportService,
TransportSearchAction searchAction,
ControllerSingleton controllerSingleton,
ActionFilters actionFilters,
IndexNameExpressionResolver indexNameExpressionResolver,
NamedXContentRegistry xContentRegistry) {
super(settings,
ClusteringAction.NAME,
threadPool,
actionFilters,
indexNameExpressionResolver,
transportService.getTaskManager());
this.searchAction = searchAction;
this.controllerSingleton = controllerSingleton;
// Incoming transport messages for this action are routed to TransportHandler,
// which in turn calls execute(...) on this action.
transportService.registerRequestHandler(
ClusteringAction.NAME,
ClusteringActionRequest::new,
ThreadPool.Names.SAME,
new TransportHandler());
}
/**
 * Executes the delegate search request and clusters the returned hits.
 * <p>
 * The listener is notified exactly once: with the clustering response on success, or with
 * the failure of the search, of algorithm selection or of the clustering itself.
 *
 * @param clusteringRequest the request to process
 * @param listener receives the clustering response or the failure
 */
@Override
protected void doExecute(final ClusteringActionRequest clusteringRequest,
                         final ActionListener<ClusteringActionResponse> listener) {
    final long tsSearchStart = System.nanoTime();
    searchAction.execute(clusteringRequest.getSearchRequest(), new ActionListener<SearchResponse>() {
        @Override
        public void onFailure(Exception e) {
            listener.onFailure(e);
        }

        @Override
        public void onResponse(SearchResponse response) {
            final long tsSearchEnd = System.nanoTime();

            // Pick the requested algorithm or fall back to the first available one.
            final List<String> algorithmComponentIds = controllerSingleton.getAlgorithms();
            final String requestedAlgorithm = clusteringRequest.getAlgorithm();
            final String algorithmId;
            if (requestedAlgorithm != null) {
                if (!algorithmComponentIds.contains(requestedAlgorithm)) {
                    listener.onFailure(new IllegalArgumentException("No such algorithm: " + requestedAlgorithm));
                    return;
                }
                algorithmId = requestedAlgorithm;
            } else {
                if (algorithmComponentIds.isEmpty()) {
                    listener.onFailure(new ElasticsearchException("No clustering algorithms are available."));
                    return;
                }
                algorithmId = algorithmComponentIds.get(0);
            }

            /*
             * We're not a threaded listener so we're running on the search thread. This
             * is good -- we don't want to serve more clustering requests than we can handle
             * anyway.
             */
            final Controller controller = controllerSingleton.getController();

            final ClusteringActionResponse clusteringResponse;
            try {
                // Work on a copy of the request's attributes: the "clusters" entry is replaced
                // with parsed Cluster objects and the request itself must stay untouched (it
                // may be backed by an unmodifiable map, re-executed or serialized later).
                final Map<String, Object> processingAttrs = new HashMap<>();
                final Map<String, Object> requestAttrs = clusteringRequest.getAttributes();
                if (requestAttrs != null) {
                    processingAttrs.putAll(requestAttrs);
                    handleInputClustersSpec(processingAttrs);
                }

                CommonAttributesDescriptor.attributeBuilder(processingAttrs)
                    .documents(prepareDocumentsForClustering(clusteringRequest, response))
                    .query(clusteringRequest.getQueryHint());

                final long tsClusteringStart = System.nanoTime();
                final ProcessingResult result = AccessController.doPrivileged((PrivilegedAction<ProcessingResult>) () ->
                    controller.process(processingAttrs, algorithmId));
                final DocumentGroup[] groups = adapt(result.getClusters());
                final long tsClusteringEnd = System.nanoTime();

                final Map<String, String> info = new LinkedHashMap<>();
                info.put(ClusteringActionResponse.Fields.Info.ALGORITHM,
                    algorithmId);
                info.put(ClusteringActionResponse.Fields.Info.SEARCH_MILLIS,
                    Long.toString(TimeUnit.NANOSECONDS.toMillis(tsSearchEnd - tsSearchStart)));
                info.put(ClusteringActionResponse.Fields.Info.CLUSTERING_MILLIS,
                    Long.toString(TimeUnit.NANOSECONDS.toMillis(tsClusteringEnd - tsClusteringStart)));
                info.put(ClusteringActionResponse.Fields.Info.TOTAL_MILLIS,
                    Long.toString(TimeUnit.NANOSECONDS.toMillis(tsClusteringEnd - tsSearchStart)));
                info.put(ClusteringActionResponse.Fields.Info.INCLUDE_HITS,
                    Boolean.toString(clusteringRequest.getIncludeHits()));
                info.put(ClusteringActionResponse.Fields.Info.MAX_HITS,
                    clusteringRequest.getMaxHits() == Integer.MAX_VALUE ?
                        "" : Integer.toString(clusteringRequest.getMaxHits()));

                // Trim search response's hits if we need to.
                SearchResponse returnedResponse = response;
                if (clusteringRequest.getMaxHits() != Integer.MAX_VALUE) {
                    returnedResponse = filterMaxHits(response, clusteringRequest.getMaxHits());
                }

                clusteringResponse = new ClusteringActionResponse(returnedResponse, groups, info);
            } catch (ProcessingException e) {
                // Log a full stack trace with all nested exceptions but only return
                // ElasticSearchException exception with a simple String (otherwise
                // clients cannot deserialize exception classes).
                logger.warn("Could not process clustering request.", e);
                listener.onFailure(new ElasticsearchException("Search results clustering error: " + e.getMessage()));
                return;
            } catch (RuntimeException e) {
                // Anything else (for example a malformed "clusters" attribute) must still
                // reach the listener, otherwise the caller would never get an answer.
                logger.warn("Could not process clustering request.", e);
                listener.onFailure(e);
                return;
            }

            // Outside of the try block so that a failure inside the listener itself does
            // not trigger a second, contradictory onFailure notification.
            listener.onResponse(clusteringResponse);
        }
    });
}
/**
 * Replaces the raw "clusters" entry of the given attribute map (if the key is present)
 * with a list of parsed {@link Cluster} objects. The map is modified in place.
 */
@SuppressWarnings("unchecked")
private void handleInputClustersSpec(Map<String, Object> requestAttrs) {
    // "clusters" is a "special" attribute key which Lingo3G recognizes as a request for
    // incremental clustering. The value must follow this xcontent schema:
    //
    //   "clusters": [{}, {}, ...]
    //
    // where each of the zero, one or more objects represents a cluster label and,
    // optionally, its nested clusters (read from the nested "clusters" key):
    //
    //   { "label": "cluster label",
    //     "clusters": [{}, {}, ...] }
    //
    // There is very limited input validation; this feature is largely undocumented and
    // officially unsupported.
    final String key = "clusters";
    if (!requestAttrs.containsKey(key)) {
        return;
    }

    List<Object> rawClusters = (List<Object>) requestAttrs.get(key);
    requestAttrs.put(key, parseClusters(rawClusters));
}
/**
 * Converts a list of xcontent maps into {@link Cluster} objects, preserving order.
 * Each element is expected to be a map accepted by {@code parseCluster}.
 */
@SuppressWarnings("unchecked")
private List<Cluster> parseClusters(List<Object> xcontentList) {
    final List<Cluster> parsed = new ArrayList<>();
    xcontentList.forEach(element -> parsed.add(parseCluster((Map<String, Object>) element)));
    return parsed;
}
/**
 * Builds a {@link Cluster} from an xcontent map: the "label" entry becomes the cluster's
 * label and the optional "clusters" entry is parsed recursively into subclusters.
 */
@SuppressWarnings("unchecked")
private Cluster parseCluster(Map<String, Object> xcontent) {
    final Cluster parsed = new Cluster((String) xcontent.get("label"));

    final List<Object> nested = (List<Object>) xcontent.get("clusters");
    if (nested == null) {
        return parsed;
    }
    return parsed.addSubclusters(parseClusters(nested));
}
/**
 * Returns a copy of the given search response that carries at most {@code maxHits} hits
 * (the leading ones); all other parts of the response are carried over.
 *
 * @param response the original search response
 * @param maxHits maximum number of hits to keep
 * @return a new search response with the trimmed hits
 */
protected SearchResponse filterMaxHits(SearchResponse response, int maxHits) {
    // Internal APIs are used here on purpose (efficiency); the plugin is tied to a specific
    // ES version anyway. The alternative would be a JSON serialize/filter/deserialize round trip.
    final SearchHits originalHits = response.getHits();
    final SearchHit[] hitArray = originalHits.getHits();
    final SearchHit[] kept =
        Arrays.copyOf(hitArray, Math.min(maxHits, hitArray.length), SearchHit[].class);

    final InternalAggregations aggregations = (response.getAggregations() == null)
        ? null
        : new InternalAggregations(toInternal(response.getAggregations().asList()));

    final InternalSearchResponse trimmedInternal = new InternalSearchResponse(
        new SearchHits(kept, originalHits.getTotalHits(), originalHits.getMaxScore()),
        aggregations,
        response.getSuggest(),
        new SearchProfileShardResults(response.getProfileResults()),
        response.isTimedOut(),
        response.isTerminatedEarly(),
        response.getNumReducePhases());

    return new SearchResponse(
        trimmedInternal,
        response.getScrollId(),
        response.getTotalShards(),
        response.getSuccessfulShards(),
        response.getTookInMillis(),
        response.getShardFailures());
}
/**
 * Casts every aggregation in the list to {@link InternalAggregation}, preserving order.
 */
private List<InternalAggregation> toInternal(List<Aggregation> list) {
    final List<InternalAggregation> internal = new ArrayList<>(list.size());
    list.forEach(aggregation -> internal.add((InternalAggregation) aggregation));
    return internal;
}
/**
 * Converts a list of clusters into an array of document groups, one per cluster
 * and in the same order.
 */
protected DocumentGroup[] adapt(List<Cluster> clusters) {
    final DocumentGroup[] converted = new DocumentGroup[clusters.size()];
    int slot = 0;
    while (slot < converted.length) {
        converted[slot] = adapt(clusters.get(slot));
        slot++;
    }
    return converted;
}
/**
 * Converts a single cluster (and, recursively, its subclusters) into a document group.
 * Member documents are represented by their string identifiers.
 */
private DocumentGroup adapt(Cluster cluster) {
    final DocumentGroup converted = new DocumentGroup();
    converted.setId(cluster.getId());

    final List<String> phraseList = cluster.getPhrases();
    converted.setPhrases(phraseList.toArray(new String[phraseList.size()]));

    converted.setLabel(cluster.getLabel());
    converted.setScore(cluster.getScore());
    converted.setOtherTopics(cluster.isOtherTopics());

    // Collect identifiers of the cluster's documents.
    final List<Document> members = cluster.getDocuments();
    final String[] memberIds = new String[members.size()];
    for (int slot = 0; slot < memberIds.length; slot++) {
        memberIds[slot] = members.get(slot).getStringId();
    }
    converted.setDocumentReferences(memberIds);

    // A missing subcluster list is treated as an empty one.
    List<Cluster> children = cluster.getSubclusters();
    if (children == null) {
        children = Collections.emptyList();
    }
    converted.setSubgroups(adapt(children));

    return converted;
}
/**
 * Map {@link SearchHit} fields to logical fields of Carrot2 {@link Document}.
 * <p>
 * For every hit, the request's field mapping specifications are applied in order. Each
 * specification reads its content from one of three places: a field of the hit, the
 * joined highlight fragments of a field, or a (dot-separated) path inside the hit's
 * source map. The content is then routed to the title, content, URL or language buffer.
 * Title and content accumulate all mapped values (separated by " . "), while URL and
 * language keep only the value of the last specification that produced content.
 *
 * @param request the clustering request providing the field mapping
 * @param response the search response whose hits are converted
 * @return one document per hit, in hit order
 */
private List<Document> prepareDocumentsForClustering(
final ClusteringActionRequest request,
SearchResponse response) {
SearchHit[] hits = response.getHits().getHits();
List<Document> documents = new ArrayList<>(hits.length);
List<FieldMappingSpec> fieldMapping = request.getFieldMapping();
// Buffers are allocated once and reset for each hit.
StringBuilder title = new StringBuilder();
StringBuilder content = new StringBuilder();
StringBuilder url = new StringBuilder();
StringBuilder language = new StringBuilder();
// The "no source available" warning is logged at most once per call.
boolean emptySourceWarningEmitted = false;
for (SearchHit hit : hits) {
// Prepare logical fields for each hit.
title.setLength(0);
content.setLength(0);
url.setLength(0);
language.setLength(0);
Map<String, SearchHitField> fields = hit.getFields();
Map<String, HighlightField> highlightFields = hit.getHighlightFields();
// Loaded lazily, on the first SOURCE mapping of a hit that has a source.
Map<String, Object> sourceAsMap = null;
for (FieldMappingSpec spec : fieldMapping) {
// Determine the content source.
Object appendContent = null;
// The label lets the nested path-descent loop below leave the whole switch,
// which leaves appendContent at null for this specification.
outer:
switch (spec.source) {
case FIELD:
SearchHitField searchHitField = fields.get(spec.field);
if (searchHitField != null) {
appendContent = searchHitField.getValue();
}
break;
case HIGHLIGHT:
HighlightField highlightField = highlightFields.get(spec.field);
if (highlightField != null) {
appendContent = join(Arrays.asList(highlightField.fragments()));
}
break;
case SOURCE:
if (sourceAsMap == null) {
if (!hit.hasSource()) {
if (!emptySourceWarningEmitted) {
emptySourceWarningEmitted = true;
logger.warn("_source field mapping used but no source available for: {}, field {}",
hit.getId(),
spec.field);
}
} else {
sourceAsMap = hit.getSourceAsMap();
}
}
if (sourceAsMap != null) {
// The field specification is a dot-separated path into nested maps.
String[] fieldNames = spec.field.split("\\.");
Object value = sourceAsMap;
// Descend into maps.
for (String fieldName : fieldNames) {
if (Map.class.isInstance(value)) {
value = ((Map<?, ?>) value).get(fieldName);
if (value == null) {
// No such key.
logger.warn("Cannot find into field {} from spec: {}",
fieldName,
spec.field);
break outer;
}
} else {
logger.warn("Field is not a map: {} in spec.: {}",
fieldName,
spec.field);
break outer;
}
}
// List values are flattened into a single string.
if (value instanceof List) {
appendContent = join((List<?>) value);
} else {
appendContent = value;
}
}
break;
default:
throw org.carrot2.elasticsearch.Preconditions.unreachable();
}
// Determine the target field.
if (appendContent != null) {
StringBuilder target;
switch (spec.logicalField) {
case URL:
url.setLength(0); // Clear previous (single mapping allowed).
target = url;
break;
case LANGUAGE:
language.setLength(0); // Clear previous (single mapping allowed);
target = language;
break;
case TITLE:
target = title;
break;
case CONTENT:
target = content;
break;
default:
throw org.carrot2.elasticsearch.Preconditions.unreachable();
}
// Separate multiple fields with a single dot (prevent accidental phrase gluing).
// NOTE(review): this null check repeats the enclosing condition and is always true here.
if (appendContent != null) {
if (target.length() > 0) {
target.append(" . ");
}
target.append(appendContent);
}
}
}
// Resolve the language code, if any. Each distinct unresolvable code string is
// reported once (tracked in langCodeWarnings); the document then gets a null language.
LanguageCode langCode = null;
if (language.length() > 0) {
String langCodeString = language.toString();
langCode = LanguageCode.forISOCode(langCodeString);
if (langCode == null && langCodeWarnings.add(langCodeString)) {
logger.warn("Language mapping not a supported ISO639-1 code: {}", langCodeString);
}
}
// The hit's identifier is passed as the last constructor argument of the document.
Document doc = new Document(
title.toString(),
content.toString(),
url.toString(),
langCode,
hit.getId());
documents.add(doc);
}
return documents;
}
/**
 * Concatenates the string forms of the given elements, inserting {@code " . "}
 * between pieces so that adjacent values are not glued into one phrase.
 *
 * <p>A {@code null} element contributes an empty string. The separator is only
 * written once some text has already been accumulated, so leading {@code null}
 * or empty elements do not produce a leading separator.
 *
 * @param list the elements to join; must not be {@code null}
 * @return the joined text, or an empty string for an empty list
 */
static String join(List<?> list) {
    final StringBuilder joined = new StringBuilder();
    for (final Object element : list) {
        if (joined.length() != 0) {
            joined.append(" . ");
        }
        if (element != null) {
            joined.append(element.toString());
        }
    }
    return joined.toString();
}
/**
 * Transport request handler that runs an incoming {@link ClusteringActionRequest}
 * through the enclosing action's {@code execute} method and relays the outcome
 * (response or failure) back over the originating channel.
 */
private final class TransportHandler implements TransportRequestHandler<ClusteringActionRequest> {
    @Override
    public void messageReceived(final ClusteringActionRequest request, final TransportChannel channel) throws Exception {
        final ActionListener<ClusteringActionResponse> relay = new ActionListener<ClusteringActionResponse>() {
            @Override
            public void onResponse(ClusteringActionResponse response) {
                try {
                    channel.sendResponse(response);
                } catch (Exception sendError) {
                    // Could not deliver the response; report that as a failure instead.
                    onFailure(sendError);
                }
            }

            @Override
            public void onFailure(Exception cause) {
                try {
                    channel.sendResponse(cause);
                } catch (Exception sendError) {
                    // Nothing more can be sent to the caller; just log it.
                    final String message = "Failed to send error response for action ["
                        + ClusteringAction.NAME + "] and request [" + request + "]";
                    logger.warn(message, sendError);
                }
            }
        };
        execute(request, relay);
    }
}
}
/**
 * A {@link BaseRestHandler} for {@link ClusteringAction}.
 *
 * <p>The handler is registered under {@code /_search_with_clusters}, optionally
 * prefixed with {@code {index}} and {@code {index}/{type}} path segments, for both
 * POST and GET. POST requests must carry a body, which is handed to the request
 * builder as the request source. GET requests must not carry a body and are
 * configured from URL parameters only.
 */
public static class RestClusteringAction extends BaseRestHandler {
    /**
     * Action name suffix. Declared {@code final} so that the route name cannot be
     * reassigned at runtime.
     */
    public static final String NAME = "_search_with_clusters";

    /**
     * Creates the handler and registers it with the given controller for all
     * supported method/path combinations.
     *
     * @param settings node settings passed to the base handler
     * @param controller the controller the routes are registered with
     */
    public RestClusteringAction(
        Settings settings,
        RestController controller) {
        super(settings);
        controller.registerHandler(POST, "/" + NAME, this);
        controller.registerHandler(POST, "/{index}/" + NAME, this);
        controller.registerHandler(POST, "/{index}/{type}/" + NAME, this);
        controller.registerHandler(GET, "/" + NAME, this);
        controller.registerHandler(GET, "/{index}/" + NAME, this);
        controller.registerHandler(GET, "/{index}/{type}/" + NAME, this);
    }

    /**
     * Validates the HTTP request, builds a clustering action request from it and
     * returns a consumer that executes the action and writes the result (or an
     * error) to the REST channel.
     *
     * @param request the incoming REST request
     * @param client the node client used to execute the clustering action
     * @return a consumer that emits either the clustering response or an error response
     * @throws IOException declared by the overridden method
     */
    @Override
    public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
        // A POST request must have a body.
        if (request.method() == POST && !request.hasContent()) {
            return channel -> emitErrorResponse(channel, logger,
                new IllegalArgumentException("Request body was expected for a POST request."));
        }

        // Contrary to ES's default search handler we will not support
        // GET requests with a body (this is against HTTP spec guidance
        // in my opinion -- GET requests should be URL-based).
        if (request.method() == GET && request.hasContent()) {
            return channel -> emitErrorResponse(channel, logger,
                new IllegalArgumentException("Request body was unexpected for a GET request."));
        }

        // Build an action request with data from the request.
        // Parse incoming arguments depending on the HTTP method used to make
        // the request.
        final ClusteringActionRequestBuilder actionBuilder = new ClusteringActionRequestBuilder(client);
        SearchRequest searchRequest = new SearchRequest();
        switch (request.method()) {
            case POST:
                // Indices and types come from the URL path; everything else from the body.
                searchRequest.indices(Strings.splitStringByCommaToArray(request.param("index")));
                searchRequest.types(Strings.splitStringByCommaToArray(request.param("type")));
                actionBuilder.setSearchRequest(searchRequest);
                actionBuilder.setSource(request.content(), request.getXContentType(), request.getXContentRegistry());
                break;

            case GET:
                // The search part is parsed by the stock search handler's helper;
                // clustering-specific parameters are read separately.
                RestSearchAction.parseSearchRequest(searchRequest, request, null);
                actionBuilder.setSearchRequest(searchRequest);
                fillFromGetRequest(actionBuilder, request);
                break;

            default:
                // Only POST and GET routes are registered in the constructor.
                throw org.carrot2.elasticsearch.Preconditions.unreachable();
        }

        // Dispatch clustering request.
        return channel -> client.execute(ClusteringAction.INSTANCE, actionBuilder.request(),
            new ActionListener<ClusteringActionResponse>() {
                @Override
                public void onResponse(ClusteringActionResponse response) {
                    try {
                        XContentBuilder builder = channel.newBuilder();
                        builder.startObject();
                        response.toXContent(builder, request);
                        builder.endObject();
                        // The HTTP status mirrors the status of the wrapped search response.
                        channel.sendResponse(
                            new BytesRestResponse(
                                response.getSearchResponse().status(),
                                builder));
                    } catch (Exception e) {
                        logger.debug("Failed to emit response.", e);
                        onFailure(e);
                    }
                }

                @Override
                public void onFailure(Exception e) {
                    emitErrorResponse(channel, logger, e);
                }
            });
    }

    /**
     * GET parameter name for each logical field: {@code field_mapping_} followed by
     * the lower-cased enum constant name.
     */
    private static final EnumMap<LogicalField, String> GET_REQUEST_FIELDMAPPERS;
    static {
        GET_REQUEST_FIELDMAPPERS = new EnumMap<>(LogicalField.class);
        for (LogicalField lf : LogicalField.values()) {
            GET_REQUEST_FIELDMAPPERS.put(lf, "field_mapping_" + lf.name().toLowerCase(Locale.ROOT));
        }
    }

    /**
     * Extract and parse HTTP GET parameters for the clustering request.
     *
     * <p>Recognized parameters: {@code query_hint} (falling back to {@code q}),
     * {@code algorithm}, {@code include_hits}, {@code max_hits} and one
     * {@code field_mapping_*} parameter per logical field, each holding a
     * comma-separated list of field mapping specs.
     *
     * @param actionBuilder the builder that receives the parsed values
     * @param request the GET request to read parameters from
     */
    private void fillFromGetRequest(
        ClusteringActionRequestBuilder actionBuilder,
        RestRequest request) {
        // Use the search query as the query hint, if explicit query hint
        // is not available.
        if (request.hasParam("query_hint")) {
            actionBuilder.setQueryHint(request.param("query_hint"));
        } else {
            actionBuilder.setQueryHint(request.param("q"));
        }

        // Algorithm.
        if (request.hasParam("algorithm")) {
            actionBuilder.setAlgorithm(request.param("algorithm"));
        }

        // include_hits
        if (request.hasParam("include_hits")) {
            actionBuilder.setIncludeHits(request.param("include_hits"));
        }

        // max_hits
        if (request.hasParam("max_hits")) {
            actionBuilder.setMaxHits(request.param("max_hits"));
        }

        // Field mappers.
        for (Map.Entry<LogicalField, String> e : GET_REQUEST_FIELDMAPPERS.entrySet()) {
            if (request.hasParam(e.getValue())) {
                for (String spec : Strings.splitStringByCommaToArray(request.param(e.getValue()))) {
                    actionBuilder.addFieldMappingSpec(spec, e.getKey());
                }
            }
        }
    }
}
}